import unittest
import os
import textwrap
import contextlib
import importlib
import sys
import socket
import threading
import time
from contextlib import contextmanager
from asyncio import staggered, taskgroups, base_events, tasks
from unittest.mock import ANY
from test.support import (
    os_helper,
    SHORT_TIMEOUT,
    busy_retry,
    requires_gil_enabled,
)
from test.support.script_helper import make_script
from test.support.socket_helper import find_unused_port

import subprocess

# Profiling mode constants
PROFILING_MODE_WALL = 0
PROFILING_MODE_CPU = 1
PROFILING_MODE_GIL = 2
PROFILING_MODE_ALL = 3

# Thread status flags
THREAD_STATUS_HAS_GIL = 1 << 0
THREAD_STATUS_ON_CPU = 1 << 1
THREAD_STATUS_UNKNOWN = 1 << 2

# Maximum number of retry attempts for operations that may fail transiently
MAX_TRIES = 10

try:
    from concurrent import interpreters
except ImportError:
    interpreters = None

PROCESS_VM_READV_SUPPORTED = False

try:
    from _remote_debugging import PROCESS_VM_READV_SUPPORTED
    from _remote_debugging import RemoteUnwinder
    from _remote_debugging import FrameInfo, CoroInfo, TaskInfo
except ImportError:
    raise unittest.SkipTest(
        "Test only runs when _remote_debugging is available"
    )


# ============================================================================
# Module-level helper functions
# ============================================================================


def _make_test_script(script_dir, script_basename, source):
    """Write *source* to a script file and refresh import caches.

    Returns the path of the created script.
    """
    script_path = make_script(script_dir, script_basename, source)
    # Reset finder caches so the freshly written file is visible to imports.
    importlib.invalidate_caches()
    return script_path


def _create_server_socket(port, backlog=1):
    """Return a listening TCP socket bound to localhost:*port*.

    The socket has SO_REUSEADDR set, an accept timeout of SHORT_TIMEOUT,
    and is already listening with the given *backlog*.
    """
    listener = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    # Allow rebinding to a recently used port between test runs.
    listener.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
    listener.bind(("localhost", port))
    # Bound accept() so a wedged child process cannot hang the test.
    listener.settimeout(SHORT_TIMEOUT)
    listener.listen(backlog)
    return listener


def _wait_for_signal(sock, expected_signals, timeout=SHORT_TIMEOUT):
    """
    Wait for expected signal(s) from a socket with proper timeout and EOF handling.

    Args:
        sock: Connected socket to read from
        expected_signals: Single bytes object or list of bytes objects to wait for
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed before signal received or timeout
    """
    # Normalize a single signal into a one-element list.
    if isinstance(expected_signals, bytes):
        expected_signals = [expected_signals]

    sock.settimeout(timeout)
    received = b""

    # Keep reading until every expected marker has been seen.
    while not all(marker in received for marker in expected_signals):
        try:
            data = sock.recv(4096)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for signals. "
                f"Expected: {expected_signals}, Got: {received[-200:]!r}"
            )
        if not data:
            # EOF - connection closed
            raise RuntimeError(
                f"Connection closed before receiving expected signals. "
                f"Expected: {expected_signals}, Got: {received[-200:]!r}"
            )
        received += data

    return received


def _wait_for_n_signals(sock, signal_pattern, count, timeout=SHORT_TIMEOUT):
    """
    Wait for N occurrences of a signal pattern.

    Args:
        sock: Connected socket to read from
        signal_pattern: bytes pattern to count (e.g., b"ready")
        count: Number of occurrences expected
        timeout: Socket timeout in seconds

    Returns:
        bytes: Complete accumulated response buffer

    Raises:
        RuntimeError: If connection closed or timeout before receiving all signals
    """
    sock.settimeout(timeout)
    received = b""
    seen = 0

    while seen < count:
        try:
            data = sock.recv(4096)
        except socket.timeout:
            raise RuntimeError(
                f"Timeout waiting for {count} signals (found {seen}). "
                f"Last 200 bytes: {received[-200:]!r}"
            )
        if not data:
            raise RuntimeError(
                f"Connection closed after {seen}/{count} signals. "
                f"Last 200 bytes: {received[-200:]!r}"
            )
        received += data
        # Recount over the whole buffer: a pattern may straddle two reads.
        seen = received.count(signal_pattern)

    return received


@contextmanager
def _managed_subprocess(args, timeout=SHORT_TIMEOUT):
    """
    Context manager for subprocess lifecycle management.

    Ensures the process is terminated and reaped even on exceptions,
    escalating from a graceful terminate() to a forceful kill().
    """
    proc = subprocess.Popen(args)
    try:
        yield proc
    finally:
        try:
            # Try the gentle stop first; escalate to kill() if the
            # process does not exit within the timeout.
            for stop in (proc.terminate, proc.kill):
                stop()
                try:
                    proc.wait(timeout=timeout)
                    break
                except subprocess.TimeoutExpired:
                    continue  # escalate to the next, harsher stop
        except OSError:
            pass  # Process already dead


def _cleanup_sockets(*sockets):
    """Safely close multiple sockets, ignoring errors."""
    for sock in sockets:
        if sock is not None:
            try:
                sock.close()
            except OSError:
                pass


# ============================================================================
# Decorators and skip conditions
# ============================================================================

# Skip decorator for platforms where these tests cannot run; the remote
# inspection tests only run on Linux, Windows and macOS.
skip_if_not_supported = unittest.skipIf(
    sys.platform not in ("darwin", "linux", "win32"),
    "Test only runs on Linux, Windows and MacOS",
)


def requires_subinterpreters(meth):
    """Decorator to skip a test if subinterpreters are not supported."""
    # `interpreters` is None when `concurrent.interpreters` failed to import.
    skip_decorator = unittest.skipIf(
        interpreters is None, "subinterpreters required"
    )
    return skip_decorator(meth)


# ============================================================================
# Simple wrapper functions for RemoteUnwinder
# ============================================================================

def get_stack_trace(pid):
    """Return the remote stack trace of process *pid*, retrying transiently.

    RemoteUnwinder may raise RuntimeError while the target process is still
    starting up or is in an inconsistent state; keep retrying until
    busy_retry() gives up.

    Raises:
        RuntimeError: if no attempt succeeds, chained to the last unwinder
            error for easier diagnosis.
    """
    last_exc = None
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, all_threads=True, debug=True)
            return unwinder.get_stack_trace()
        except RuntimeError as e:
            last_exc = e
    raise RuntimeError("Failed to get stack trace after retries") from last_exc


def get_async_stack_trace(pid):
    """Return the remote async stack trace of *pid*, retrying transiently.

    RemoteUnwinder may raise RuntimeError while the target process is still
    starting up or is in an inconsistent state; keep retrying until
    busy_retry() gives up.

    Raises:
        RuntimeError: if no attempt succeeds, chained to the last unwinder
            error for easier diagnosis.
    """
    last_exc = None
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_async_stack_trace()
        except RuntimeError as e:
            last_exc = e
    raise RuntimeError(
        "Failed to get async stack trace after retries"
    ) from last_exc


def get_all_awaited_by(pid):
    """Return the remote awaited_by relationships of *pid*, retrying transiently.

    RemoteUnwinder may raise RuntimeError while the target process is still
    starting up or is in an inconsistent state; keep retrying until
    busy_retry() gives up.

    Raises:
        RuntimeError: if no attempt succeeds, chained to the last unwinder
            error for easier diagnosis.
    """
    last_exc = None
    for _ in busy_retry(SHORT_TIMEOUT):
        try:
            unwinder = RemoteUnwinder(pid, debug=True)
            return unwinder.get_all_awaited_by()
        except RuntimeError as e:
            last_exc = e
    raise RuntimeError(
        "Failed to get all awaited_by after retries"
    ) from last_exc


# ============================================================================
# Base test class with shared infrastructure
# ============================================================================


class RemoteInspectionTestBase(unittest.TestCase):
    """Base class for remote inspection tests with common helpers.

    Provides the shared script-running/socket-handshake machinery plus
    small query helpers over the structures returned by RemoteUnwinder.
    """

    # Show full diffs for the large nested structures these tests compare.
    maxDiff = None

    def _run_script_and_get_trace(
        self,
        script,
        trace_func,
        wait_for_signals=None,
        port=None,
        backlog=1,
    ):
        """
        Common pattern: run a script, wait for signals, get trace.

        Args:
            script: Script content (will be formatted with port if {port} present)
            trace_func: Function to call with pid to get trace (e.g., get_stack_trace)
            wait_for_signals: Signal(s) to wait for before getting trace
            port: Port to use (auto-selected if None)
            backlog: Socket listen backlog

        Returns:
            tuple: (trace_result, script_name)

        Raises:
            unittest.SkipTest: via skipTest() when reading the trace is not
                permitted for this process.
        """
        if port is None:
            port = find_unused_port()

        # Format script with port if needed.  Both "{port}" and the
        # double-braced "{{port}}" spelling are accepted; the latter is
        # normalized first so one .format() call fills both in.
        if "{port}" in script or "{{port}}" in script:
            script = script.replace("{{port}}", "{port}").format(port=port)

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port, backlog)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    # Only the accepted connection is needed from here on;
                    # close the listener and clear the variable so the
                    # finally-block cleanup does not close it twice.
                    server_socket.close()
                    server_socket = None

                    if wait_for_signals:
                        _wait_for_signal(client_socket, wait_for_signals)

                    try:
                        trace = trace_func(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )
                    return trace, script_name
            finally:
                _cleanup_sockets(client_socket, server_socket)

    def _find_frame_in_trace(self, stack_trace, predicate):
        """
        Find a frame matching predicate in stack trace.

        Args:
            stack_trace: List of InterpreterInfo objects
            predicate: Function(frame) -> bool

        Returns:
            FrameInfo or None
        """
        # Walk interpreters -> threads -> frames; first match wins.
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                for frame in thread_info.frame_info:
                    if predicate(frame):
                        return frame
        return None

    def _find_thread_by_id(self, stack_trace, thread_id):
        """Find a thread by its native thread ID.

        Returns the ThreadInfo, or None if no thread matches.
        """
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                if thread_info.thread_id == thread_id:
                    return thread_info
        return None

    def _find_thread_with_frame(self, stack_trace, frame_predicate):
        """Find a thread containing a frame matching predicate.

        Returns the ThreadInfo, or None if no frame matches.
        """
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                for frame in thread_info.frame_info:
                    if frame_predicate(frame):
                        return thread_info
        return None

    def _get_thread_statuses(self, stack_trace):
        """Extract thread_id -> status mapping from stack trace."""
        statuses = {}
        for interpreter_info in stack_trace:
            for thread_info in interpreter_info.threads:
                statuses[thread_info.thread_id] = thread_info.status
        return statuses

    def _get_task_id_map(self, stack_trace):
        """Create task_id -> task mapping from async stack trace."""
        # Only the first interpreter's tasks are considered.
        return {task.task_id: task for task in stack_trace[0].awaited_by}

    def _get_awaited_by_relationships(self, stack_trace):
        """Extract task name to awaited_by set mapping."""
        # NOTE(review): `awaited.task_name` is used below as a key into the
        # task_id -> task map, so for awaited_by entries it appears to hold
        # a task id rather than a display name -- confirm against CoroInfo.
        id_to_task = self._get_task_id_map(stack_trace)
        return {
            task.task_name: set(
                id_to_task[awaited.task_name].task_name
                for awaited in task.awaited_by
            )
            for task in stack_trace[0].awaited_by
        }

    def _extract_coroutine_stacks(self, stack_trace):
        """Extract and format coroutine stacks from tasks.

        Returns a task_name -> sorted list of frame-tuple stacks mapping.
        """
        return {
            task.task_name: sorted(
                tuple(tuple(frame) for frame in coro.call_stack)
                for coro in task.coroutine_stack
            )
            for task in stack_trace[0].awaited_by
        }


# ============================================================================
# Test classes
# ============================================================================


class TestGetStackTrace(RemoteInspectionTestBase):
    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_remote_stack_trace(self):
        """Remotely read both a worker thread's and the main thread's stack.

        The child script parks a worker thread in foo() (reached via
        bar -> baz) and sends the main thread into t.join(); both stacks
        must be visible in the remote trace.

        NOTE: the FrameInfo line numbers asserted below (15, 12, 9, 19)
        are tied to the exact layout of the embedded script -- do not
        reflow it.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def bar():
                for x in range(100):
                    if x == 50:
                        baz()

            def baz():
                foo()

            def foo():
                sock.sendall(b"ready:thread\\n"); time.sleep(10_000)

            t = threading.Thread(target=bar)
            t.start()
            sock.sendall(b"ready:main\\n"); t.join()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Both threads have checked in: the worker has reached
                    # foo() and the main thread is about to block in join().
                    _wait_for_signal(
                        client_socket, [b"ready:main", b"ready:thread"]
                    )

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    thread_expected_stack_trace = [
                        FrameInfo([script_name, 15, "foo"]),
                        FrameInfo([script_name, 12, "baz"]),
                        FrameInfo([script_name, 9, "bar"]),
                        FrameInfo([threading.__file__, ANY, "Thread.run"]),
                        FrameInfo(
                            [
                                threading.__file__,
                                ANY,
                                "Thread._bootstrap_inner",
                            ]
                        ),
                        FrameInfo(
                            [threading.__file__, ANY, "Thread._bootstrap"]
                        ),
                    ]

                    # Find expected thread stack
                    found_thread = self._find_thread_with_frame(
                        stack_trace,
                        lambda f: f.funcname == "foo" and f.lineno == 15,
                    )
                    self.assertIsNotNone(
                        found_thread, "Expected thread stack trace not found"
                    )
                    self.assertEqual(
                        found_thread.frame_info, thread_expected_stack_trace
                    )

                    # Check main thread
                    main_frame = FrameInfo([script_name, 19, "<module>"])
                    found_main = self._find_frame_in_trace(
                        stack_trace, lambda f: f == main_frame
                    )
                    self.assertIsNotNone(
                        found_main, "Main thread stack trace not found"
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_remote_stack_trace(self):
        """Reconstruct async task trees through a TaskGroup remotely.

        The child script blocks inside c5() (reached via c2 -> c3 -> c4)
        while two sibling tasks await the same task; the test checks task
        names, awaited_by relationships, and full coroutine stacks.  The
        whole scenario runs twice: once with the default loop factory and
        once with an eager task factory.

        NOTE: the coroutine-stack line numbers asserted below (10, 14,
        17, 20, 23, 26) are tied to the embedded script's exact layout --
        do not reflow it.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def c5():
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c4():
                await asyncio.sleep(0)
                c5()

            async def c3():
                await c4()

            async def c2():
                await c3()

            async def c1(task):
                await task

            async def main():
                async with asyncio.TaskGroup() as tg:
                    task = tg.create_task(c2(), name="c2_root")
                    tg.create_task(c1(task), name="sub_main_1")
                    tg.create_task(c1(task), name="sub_main_2")

            def new_eager_loop():
                loop = asyncio.new_event_loop()
                eager_task_factory = asyncio.create_eager_task_factory(
                    asyncio.Task)
                loop.set_task_factory(eager_task_factory)
                return loop

            asyncio.run(main(), loop_factory={{TASK_FACTORY}})
            """
        )

        # Run the same scenario with the default and the eager loop factory.
        for task_factory_variant in "asyncio.new_event_loop", "new_eager_loop":
            with (
                self.subTest(task_factory_variant=task_factory_variant),
                os_helper.temp_dir() as work_dir,
            ):
                script_dir = os.path.join(work_dir, "script_pkg")
                os.mkdir(script_dir)

                server_socket = _create_server_socket(port)
                script_name = _make_test_script(
                    script_dir,
                    "script",
                    script.format(TASK_FACTORY=task_factory_variant),
                )
                client_socket = None

                try:
                    with _managed_subprocess(
                        [sys.executable, script_name]
                    ) as p:
                        client_socket, _ = server_socket.accept()
                        server_socket.close()
                        server_socket = None

                        response = _wait_for_signal(client_socket, b"ready")
                        self.assertIn(b"ready", response)

                        try:
                            stack_trace = get_async_stack_trace(p.pid)
                        except PermissionError:
                            self.skipTest(
                                "Insufficient permissions to read the stack trace"
                            )

                        # Check all tasks are present
                        tasks_names = [
                            task.task_name
                            for task in stack_trace[0].awaited_by
                        ]
                        for task_name in [
                            "c2_root",
                            "sub_main_1",
                            "sub_main_2",
                        ]:
                            self.assertIn(task_name, tasks_names)

                        # Check awaited_by relationships
                        relationships = self._get_awaited_by_relationships(
                            stack_trace
                        )
                        self.assertEqual(
                            relationships,
                            {
                                "c2_root": {
                                    "Task-1",
                                    "sub_main_1",
                                    "sub_main_2",
                                },
                                "Task-1": set(),
                                "sub_main_1": {"Task-1"},
                                "sub_main_2": {"Task-1"},
                            },
                        )

                        # Check coroutine stacks
                        coroutine_stacks = self._extract_coroutine_stacks(
                            stack_trace
                        )
                        self.assertEqual(
                            coroutine_stacks,
                            {
                                "Task-1": [
                                    (
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup._aexit",
                                            ]
                                        ),
                                        tuple(
                                            [
                                                taskgroups.__file__,
                                                ANY,
                                                "TaskGroup.__aexit__",
                                            ]
                                        ),
                                        tuple([script_name, 26, "main"]),
                                    )
                                ],
                                "c2_root": [
                                    (
                                        tuple([script_name, 10, "c5"]),
                                        tuple([script_name, 14, "c4"]),
                                        tuple([script_name, 17, "c3"]),
                                        tuple([script_name, 20, "c2"]),
                                    )
                                ],
                                "sub_main_1": [
                                    (tuple([script_name, 23, "c1"]),)
                                ],
                                "sub_main_2": [
                                    (tuple([script_name, 23, "c1"]),)
                                ],
                            },
                        )

                        # Check awaited_by coroutine stacks
                        id_to_task = self._get_task_id_map(stack_trace)
                        awaited_by_coroutine_stacks = {
                            task.task_name: sorted(
                                (
                                    id_to_task[coro.task_name].task_name,
                                    tuple(
                                        tuple(frame)
                                        for frame in coro.call_stack
                                    ),
                                )
                                for coro in task.awaited_by
                            )
                            for task in stack_trace[0].awaited_by
                        }
                        self.assertEqual(
                            awaited_by_coroutine_stacks,
                            {
                                "Task-1": [],
                                "c2_root": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    ),
                                    (
                                        "sub_main_1",
                                        (tuple([script_name, 23, "c1"]),),
                                    ),
                                    (
                                        "sub_main_2",
                                        (tuple([script_name, 23, "c1"]),),
                                    ),
                                ],
                                "sub_main_1": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    )
                                ],
                                "sub_main_2": [
                                    (
                                        "Task-1",
                                        (
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup._aexit",
                                                ]
                                            ),
                                            tuple(
                                                [
                                                    taskgroups.__file__,
                                                    ANY,
                                                    "TaskGroup.__aexit__",
                                                ]
                                            ),
                                            tuple([script_name, 26, "main"]),
                                        ),
                                    )
                                ],
                            },
                        )
                finally:
                    _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_asyncgen_remote_stack_trace(self):
        """Reconstruct a coroutine stack that runs through an async generator.

        The child script blocks inside gen_nested_call(), awaited from
        inside the async generator gen(), which is being iterated by
        main(); all three frames must appear in the single task's stack.

        NOTE: the line numbers asserted below (10, 16, 19) are tied to
        the embedded script's exact layout -- do not reflow it.
        """
        port = find_unused_port()
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def gen_nested_call():
                sock.sendall(b"ready"); time.sleep(10_000)

            async def gen():
                for num in range(2):
                    yield num
                    if num == 1:
                        await gen_nested_call()

            async def main():
                async for el in gen():
                    pass

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # For this simple asyncgen test, we only expect one task
                    self.assertEqual(len(stack_trace[0].awaited_by), 1)
                    task = stack_trace[0].awaited_by[0]
                    self.assertEqual(task.task_name, "Task-1")

                    # Check the coroutine stack
                    coroutine_stack = sorted(
                        tuple(tuple(frame) for frame in coro.call_stack)
                        for coro in task.coroutine_stack
                    )
                    self.assertEqual(
                        coroutine_stack,
                        [
                            (
                                tuple([script_name, 10, "gen_nested_call"]),
                                tuple([script_name, 16, "gen"]),
                                tuple([script_name, 19, "main"]),
                            )
                        ],
                    )

                    # No awaited_by relationships expected
                    self.assertEqual(task.awaited_by, [])
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_gather_remote_stack_trace(self):
        """Inspect the async stack of a child process blocked under gather().

        The child runs ``main`` -> ``asyncio.gather(c1(), c2())``; ``c1``
        awaits ``deep``, which sends b"ready" over a socket and then blocks
        in ``time.sleep``.  Once the signal arrives, the test calls
        ``get_async_stack_trace`` on the child's PID and verifies the task
        names, the awaited_by relationships, and the coroutine stacks.
        """
        port = find_unused_port()
        # NOTE: the frame line numbers asserted below (e.g. 21 for "main",
        # 11 for "deep", 15 for "c1") are line numbers *inside this generated
        # script*, so its text must not be reflowed.
        script = textwrap.dedent(
            f"""\
            import asyncio
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def deep():
                await asyncio.sleep(0)
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c1():
                await asyncio.sleep(0)
                await deep()

            async def c2():
                await asyncio.sleep(0)

            async def main():
                await asyncio.gather(c1(), c2())

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # The child connects back to us; after the accept the
                    # listening socket is no longer needed.
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None  # so the finally-cleanup skips it

                    # Block until the child's "deep" coroutine signals
                    # readiness (it is then parked in time.sleep).
                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Check all tasks are present
                    tasks_names = [
                        task.task_name for task in stack_trace[0].awaited_by
                    ]
                    for task_name in ["Task-1", "Task-2"]:
                        self.assertIn(task_name, tasks_names)

                    # Check awaited_by relationships: Task-2 (running c1/deep)
                    # is awaited by the gather() in Task-1 (main).
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "Task-1": set(),
                            "Task-2": {"Task-1"},
                        },
                    )

                    # Check coroutine stacks (innermost frame first)
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [(tuple([script_name, 21, "main"]),)],
                            "Task-2": [
                                (
                                    tuple([script_name, 11, "deep"]),
                                    tuple([script_name, 15, "c1"]),
                                )
                            ],
                        },
                    )

                    # Check awaited_by coroutine stacks: for each task, the
                    # (awaiting task name, frame tuple) pairs of the
                    # coroutines awaiting it.
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame) for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "Task-2": [
                                ("Task-1", (tuple([script_name, 21, "main"]),))
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_staggered_race_remote_stack_trace(self):
        """Inspect the async stack of a child blocked under staggered_race().

        Like test_async_gather_remote_stack_trace, but the child awaits
        ``asyncio.staggered.staggered_race([c1, c2], delay=None)``, so the
        expected stacks additionally contain frames from the staggered
        module (their line numbers matched with ANY).
        """
        port = find_unused_port()
        # NOTE: the frame line numbers asserted below (21 for "main",
        # 11 for "deep", 15 for "c1") refer to lines *inside this generated
        # script*, so its text must not be reflowed.
        script = textwrap.dedent(
            f"""\
            import asyncio.staggered
            import time
            import sys
            import socket

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            async def deep():
                await asyncio.sleep(0)
                sock.sendall(b"ready"); time.sleep(10_000)

            async def c1():
                await asyncio.sleep(0)
                await deep()

            async def c2():
                await asyncio.sleep(10_000)

            async def main():
                await asyncio.staggered.staggered_race(
                    [c1, c2],
                    delay=None,
                )

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # The child connects back to us; after the accept the
                    # listening socket is no longer needed.
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None  # so the finally-cleanup skips it

                    # Block until the child's "deep" coroutine signals
                    # readiness (it is then parked in time.sleep).
                    response = _wait_for_signal(client_socket, b"ready")
                    self.assertIn(b"ready", response)

                    try:
                        stack_trace = get_async_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Check all tasks are present
                    tasks_names = [
                        task.task_name for task in stack_trace[0].awaited_by
                    ]
                    for task_name in ["Task-1", "Task-2"]:
                        self.assertIn(task_name, tasks_names)

                    # Check awaited_by relationships: Task-2 (running c1/deep)
                    # is awaited by the staggered_race() in Task-1 (main).
                    relationships = self._get_awaited_by_relationships(
                        stack_trace
                    )
                    self.assertEqual(
                        relationships,
                        {
                            "Task-1": set(),
                            "Task-2": {"Task-1"},
                        },
                    )

                    # Check coroutine stacks: both tasks carry an extra frame
                    # from asyncio.staggered itself.
                    coroutine_stacks = self._extract_coroutine_stacks(
                        stack_trace
                    )
                    self.assertEqual(
                        coroutine_stacks,
                        {
                            "Task-1": [
                                (
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race",
                                        ]
                                    ),
                                    tuple([script_name, 21, "main"]),
                                )
                            ],
                            "Task-2": [
                                (
                                    tuple([script_name, 11, "deep"]),
                                    tuple([script_name, 15, "c1"]),
                                    tuple(
                                        [
                                            staggered.__file__,
                                            ANY,
                                            "staggered_race.<locals>.run_one_coro",
                                        ]
                                    ),
                                )
                            ],
                        },
                    )

                    # Check awaited_by coroutine stacks: for each task, the
                    # (awaiting task name, frame tuple) pairs of the
                    # coroutines awaiting it.
                    id_to_task = self._get_task_id_map(stack_trace)
                    awaited_by_coroutine_stacks = {
                        task.task_name: sorted(
                            (
                                id_to_task[coro.task_name].task_name,
                                tuple(
                                    tuple(frame) for frame in coro.call_stack
                                ),
                            )
                            for coro in task.awaited_by
                        )
                        for task in stack_trace[0].awaited_by
                    }
                    self.assertEqual(
                        awaited_by_coroutine_stacks,
                        {
                            "Task-1": [],
                            "Task-2": [
                                (
                                    "Task-1",
                                    (
                                        tuple(
                                            [
                                                staggered.__file__,
                                                ANY,
                                                "staggered_race",
                                            ]
                                        ),
                                        tuple([script_name, 21, "main"]),
                                    ),
                                )
                            ],
                        },
                    )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_async_global_awaited_by(self):
        """Check get_all_awaited_by() on a child running many client tasks.

        The child starts an asyncio echo server plus a TaskGroup that spawns
        NUM_TASKS ``echo_client`` tasks; each client sends b"ready" over a
        side-channel socket and then sleeps.  After all NUM_TASKS signals
        arrive, the test calls ``get_all_awaited_by`` on the child's PID and
        verifies the overall result shape plus the awaited_by structure of
        the main task, the server task, and the echo-client tasks.
        """
        # Reduced from 1000 to 100 to avoid file descriptor exhaustion
        # when running tests in parallel (e.g., -j 20)
        NUM_TASKS = 100

        port = find_unused_port()
        # NOTE: the frame line numbers asserted below (52 for "main",
        # 36 for "echo_client", 39 for "echo_client_spam") are line numbers
        # *inside this generated script*, so its text must not be reflowed.
        script = textwrap.dedent(
            f"""\
            import asyncio
            import os
            import random
            import sys
            import socket
            from string import ascii_lowercase, digits
            from test.support import socket_helper, SHORT_TIMEOUT

            HOST = '127.0.0.1'
            PORT = socket_helper.find_unused_port()
            connections = 0

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            class EchoServerProtocol(asyncio.Protocol):
                def connection_made(self, transport):
                    global connections
                    connections += 1
                    self.transport = transport

                def data_received(self, data):
                    self.transport.write(data)
                    self.transport.close()

            async def echo_client(message):
                reader, writer = await asyncio.open_connection(HOST, PORT)
                writer.write(message.encode())
                await writer.drain()

                data = await reader.read(100)
                assert message == data.decode()
                writer.close()
                await writer.wait_closed()
                sock.sendall(b"ready")
                await asyncio.sleep(SHORT_TIMEOUT)

            async def echo_client_spam(server):
                async with asyncio.TaskGroup() as tg:
                    while connections < {NUM_TASKS}:
                        msg = list(ascii_lowercase + digits)
                        random.shuffle(msg)
                        tg.create_task(echo_client("".join(msg)))
                        await asyncio.sleep(0)
                server.close()
                await server.wait_closed()

            async def main():
                loop = asyncio.get_running_loop()
                server = await loop.create_server(EchoServerProtocol, HOST, PORT)
                async with server:
                    async with asyncio.TaskGroup() as tg:
                        tg.create_task(server.serve_forever(), name="server task")
                        tg.create_task(echo_client_spam(server), name="echo client spam")

            asyncio.run(main())
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # The child connects back to us; after the accept the
                    # listening socket is no longer needed.
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None  # so the finally-cleanup skips it

                    # Wait for NUM_TASKS "ready" signals
                    try:
                        _wait_for_n_signals(client_socket, b"ready", NUM_TASKS)
                    except RuntimeError as e:
                        self.fail(str(e))

                    try:
                        all_awaited_by = get_all_awaited_by(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Expected: a list of two elements: 1 thread, 1 interp
                    self.assertEqual(len(all_awaited_by), 2)
                    # Expected: a tuple with the thread ID and the awaited_by list
                    self.assertEqual(len(all_awaited_by[0]), 2)
                    # Expected: no tasks in the fallback per-interp task list
                    self.assertEqual(all_awaited_by[1], (0, []))

                    entries = all_awaited_by[0][1]
                    # Expected: at least NUM_TASKS pending tasks
                    self.assertGreaterEqual(len(entries), NUM_TASKS)

                    # Check the main task structure: Task-1 (main) is parked
                    # in TaskGroup.__aexit__ at script line 52.
                    main_stack = [
                        FrameInfo(
                            [taskgroups.__file__, ANY, "TaskGroup._aexit"]
                        ),
                        FrameInfo(
                            [taskgroups.__file__, ANY, "TaskGroup.__aexit__"]
                        ),
                        FrameInfo([script_name, 52, "main"]),
                    ]
                    self.assertIn(
                        TaskInfo(
                            [ANY, "Task-1", [CoroInfo([main_stack, ANY])], []]
                        ),
                        entries,
                    )
                    # The "server task" sits in Server.serve_forever and is
                    # awaited by main's TaskGroup.
                    self.assertIn(
                        TaskInfo(
                            [
                                ANY,
                                "server task",
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        base_events.__file__,
                                                        ANY,
                                                        "Server.serve_forever",
                                                    ]
                                                )
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup._aexit",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup.__aexit__",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [script_name, ANY, "main"]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                            ]
                        ),
                        entries,
                    )
                    # Task-4 is an echo_client sleeping at script line 36,
                    # awaited by echo_client_spam's TaskGroup (line 39).
                    self.assertIn(
                        TaskInfo(
                            [
                                ANY,
                                "Task-4",
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        tasks.__file__,
                                                        ANY,
                                                        "sleep",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        script_name,
                                                        36,
                                                        "echo_client",
                                                    ]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                                [
                                    CoroInfo(
                                        [
                                            [
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup._aexit",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        taskgroups.__file__,
                                                        ANY,
                                                        "TaskGroup.__aexit__",
                                                    ]
                                                ),
                                                FrameInfo(
                                                    [
                                                        script_name,
                                                        39,
                                                        "echo_client_spam",
                                                    ]
                                                ),
                                            ],
                                            ANY,
                                        ]
                                    )
                                ],
                            ]
                        ),
                        entries,
                    )

                    # Every spawned echo_client task should report this same
                    # awaiter: echo_client_spam's TaskGroup at line 39.
                    expected_awaited_by = [
                        CoroInfo(
                            [
                                [
                                    FrameInfo(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup._aexit",
                                        ]
                                    ),
                                    FrameInfo(
                                        [
                                            taskgroups.__file__,
                                            ANY,
                                            "TaskGroup.__aexit__",
                                        ]
                                    ),
                                    FrameInfo(
                                        [script_name, 39, "echo_client_spam"]
                                    ),
                                ],
                                ANY,
                            ]
                        )
                    ]
                    tasks_with_awaited = [
                        task
                        for task in entries
                        if task.awaited_by == expected_awaited_by
                    ]
                    self.assertGreaterEqual(len(tasks_with_awaited), NUM_TASKS)

                    # Final task should be from echo client spam (not on Windows)
                    if sys.platform != "win32":
                        self.assertEqual(
                            tasks_with_awaited[-1].awaited_by,
                            entries[-1].awaited_by,
                        )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_self_trace(self):
        stack_trace = get_stack_trace(os.getpid())
        # NOTE: the call above must stay exactly 6 lines below the first
        # decorator -- the expected frame below pins its line number to
        # co_firstlineno + 6 (co_firstlineno of a decorated function is
        # the first decorator's line).

        # Locate the frame stack belonging to the thread running this test.
        native_id = threading.get_native_id()
        this_thread_stack = None
        for interp in stack_trace:
            for thread in interp.threads:
                if thread.thread_id == native_id:
                    this_thread_stack = thread.frame_info
                    break
            if this_thread_stack:
                break

        self.assertIsNotNone(this_thread_stack)
        # The two innermost frames must be the get_stack_trace helper and
        # this test method, at fixed offsets from their co_firstlineno.
        expected_top_frames = [
            FrameInfo(
                [
                    __file__,
                    get_stack_trace.__code__.co_firstlineno + 4,
                    "get_stack_trace",
                ]
            ),
            FrameInfo(
                [
                    __file__,
                    self.test_self_trace.__code__.co_firstlineno + 6,
                    "TestGetStackTrace.test_self_trace",
                ]
            ),
        ]
        self.assertEqual(this_thread_stack[:2], expected_top_frames)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_subinterpreters
    def test_subinterpreter_stack_trace(self):
        """Check remote stack traces across a subinterpreter boundary.

        The child spawns one thread that creates a subinterpreter running
        ``sub_worker``/``nested_func`` and another thread running
        ``main_worker``; both connect back over sockets and signal
        readiness.  The remote trace must show the main interpreter (id 0)
        with ``main_worker`` on a stack and, if a subinterpreter appears,
        its stack must contain ``sub_worker`` or ``nested_func``.
        """
        port = find_unused_port()

        import pickle

        subinterp_code = textwrap.dedent(f"""
            import socket
            import time

            def sub_worker():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub\\n")
                    time.sleep(10_000)
                nested_func()

            sub_worker()
        """).strip()

        # The subinterpreter source is spliced into the outer script as the
        # repr of its pickled bytes and unpickled in the child, so the code
        # string survives f-string interpolation without quoting issues.
        pickled_code = pickle.dumps(subinterp_code)

        script = textwrap.dedent(
            f"""
            from concurrent import interpreters
            import time
            import sys
            import socket
            import threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def main_worker():
                sock.sendall(b"ready:main\\n")
                time.sleep(10_000)

            def run_subinterp():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            sub_thread = threading.Thread(target=run_subinterp)
            sub_thread.start()

            main_thread = threading.Thread(target=main_worker)
            main_thread.start()

            main_thread.join()
            sub_thread.join()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_sockets = []

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # Accept connections from both main and subinterpreter
                    # threads, until both "ready" markers have been seen or
                    # the accept times out.
                    responses = set()
                    while len(responses) < 2:
                        try:
                            client_socket, _ = server_socket.accept()
                            client_sockets.append(client_socket)
                            response = client_socket.recv(1024)
                            if b"ready:main" in response:
                                responses.add("main")
                            if b"ready:sub" in response:
                                responses.add("sub")
                        except socket.timeout:
                            break

                    server_socket.close()
                    server_socket = None  # so the finally-cleanup skips it

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Verify we have at least one interpreter
                    self.assertGreaterEqual(len(stack_trace), 1)

                    # Look for main interpreter (ID 0) and subinterpreter (ID > 0)
                    main_interp = None
                    sub_interp = None
                    for interpreter_info in stack_trace:
                        if interpreter_info.interpreter_id == 0:
                            main_interp = interpreter_info
                        elif interpreter_info.interpreter_id > 0:
                            sub_interp = interpreter_info

                    self.assertIsNotNone(
                        main_interp, "Main interpreter should be present"
                    )

                    # Check main interpreter has expected stack trace
                    main_found = self._find_frame_in_trace(
                        [main_interp], lambda f: f.funcname == "main_worker"
                    )
                    self.assertIsNotNone(
                        main_found,
                        "Main interpreter should have main_worker in stack",
                    )

                    # If subinterpreter is present, check its stack trace
                    if sub_interp:
                        sub_found = self._find_frame_in_trace(
                            [sub_interp],
                            lambda f: f.funcname
                            in ("sub_worker", "nested_func"),
                        )
                        self.assertIsNotNone(
                            sub_found,
                            "Subinterpreter should have sub_worker or nested_func in stack",
                        )
            finally:
                _cleanup_sockets(*client_sockets, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_subinterpreters
    def test_multiple_subinterpreters_with_threads(self):
        port = find_unused_port()

        import pickle

        subinterp1_code = textwrap.dedent(f"""
            import socket
            import time
            import threading

            def worker1():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub1-t1\\n")
                    time.sleep(10_000)
                nested_func()

            def worker2():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub1-t2\\n")
                    time.sleep(10_000)
                nested_func()

            t1 = threading.Thread(target=worker1)
            t2 = threading.Thread(target=worker2)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
        """).strip()

        subinterp2_code = textwrap.dedent(f"""
            import socket
            import time
            import threading

            def worker1():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub2-t1\\n")
                    time.sleep(10_000)
                nested_func()

            def worker2():
                def nested_func():
                    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
                    sock.connect(('localhost', {port}))
                    sock.sendall(b"ready:sub2-t2\\n")
                    time.sleep(10_000)
                nested_func()

            t1 = threading.Thread(target=worker1)
            t2 = threading.Thread(target=worker2)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
        """).strip()

        pickled_code1 = pickle.dumps(subinterp1_code)
        pickled_code2 = pickle.dumps(subinterp2_code)

        script = textwrap.dedent(
            f"""
            from concurrent import interpreters
            import time
            import sys
            import socket
            import threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def main_worker():
                sock.sendall(b"ready:main\\n")
                time.sleep(10_000)

            def run_subinterp1():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code1!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            def run_subinterp2():
                subinterp = interpreters.create()
                import pickle
                pickled_code = {pickled_code2!r}
                subinterp_code = pickle.loads(pickled_code)
                subinterp.exec(subinterp_code)

            sub1_thread = threading.Thread(target=run_subinterp1)
            sub2_thread = threading.Thread(target=run_subinterp2)
            sub1_thread.start()
            sub2_thread.start()

            main_thread = threading.Thread(target=main_worker)
            main_thread.start()

            main_thread.join()
            sub1_thread.join()
            sub2_thread.join()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port, backlog=5)
            script_name = _make_test_script(script_dir, "script", script)
            client_sockets = []

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    # Accept connections from main and all subinterpreter threads
                    expected_responses = {
                        "ready:main",
                        "ready:sub1-t1",
                        "ready:sub1-t2",
                        "ready:sub2-t1",
                        "ready:sub2-t2",
                    }
                    responses = set()

                    while len(responses) < 5:
                        try:
                            client_socket, _ = server_socket.accept()
                            client_sockets.append(client_socket)
                            response = client_socket.recv(1024)
                            response_str = response.decode().strip()
                            if response_str in expected_responses:
                                responses.add(response_str)
                        except socket.timeout:
                            break

                    server_socket.close()
                    server_socket = None

                    try:
                        stack_trace = get_stack_trace(p.pid)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Verify we have multiple interpreters
                    self.assertGreaterEqual(len(stack_trace), 2)

                    # Count interpreters by ID
                    interpreter_ids = {
                        interp.interpreter_id for interp in stack_trace
                    }
                    self.assertIn(
                        0,
                        interpreter_ids,
                        "Main interpreter should be present",
                    )
                    self.assertGreaterEqual(len(interpreter_ids), 3)

                    # Count total threads
                    total_threads = sum(
                        len(interp.threads) for interp in stack_trace
                    )
                    self.assertGreaterEqual(total_threads, 5)

                    # Look for expected function names
                    all_funcnames = set()
                    for interpreter_info in stack_trace:
                        for thread_info in interpreter_info.threads:
                            for frame in thread_info.frame_info:
                                all_funcnames.add(frame.funcname)

                    expected_funcs = {
                        "main_worker",
                        "worker1",
                        "worker2",
                        "nested_func",
                    }
                    found_funcs = expected_funcs.intersection(all_funcnames)
                    self.assertGreater(len(found_funcs), 0)
            finally:
                _cleanup_sockets(*client_sockets, server_socket)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    @requires_gil_enabled("Free threaded builds don't have an 'active thread'")
    def test_only_active_thread(self):
        """RemoteUnwinder(only_active_thread=True) reports a single thread.

        The child process parks three worker threads on an Event while its
        main thread spins in a pure-Python counting loop, so the main thread
        is the one holding the GIL when we sample.  A sample taken with
        all_threads=True must see more than one thread; a sample taken with
        only_active_thread=True must see exactly one, and that thread's id
        must appear in the all-threads sample.
        """
        port = find_unused_port()
        # Child script: workers block on barrier + event; main_work() burns
        # CPU after signalling "working" over the socket.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def worker_thread(name, barrier, ready_event):
                barrier.wait()
                ready_event.wait()
                time.sleep(10_000)

            def main_work():
                sock.sendall(b"working\\n")
                count = 0
                while count < 100000000:
                    count += 1
                    if count % 10000000 == 0:
                        pass
                sock.sendall(b"done\\n")

            num_threads = 3
            barrier = threading.Barrier(num_threads + 1)
            ready_event = threading.Event()

            threads = []
            for i in range(num_threads):
                t = threading.Thread(target=worker_thread, args=(f"Worker-{{i}}", barrier, ready_event))
                t.start()
                threads.append(t)

            barrier.wait()
            sock.sendall(b"ready\\n")
            ready_event.set()
            main_work()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    # Listener no longer needed once the child connected;
                    # clear it so the finally block only closes the live
                    # client socket.
                    server_socket.close()
                    server_socket = None

                    # Wait for ready and working signals
                    _wait_for_signal(client_socket, [b"ready", b"working"])

                    try:
                        # Get stack trace with all threads
                        unwinder_all = RemoteUnwinder(p.pid, all_threads=True)
                        # Retry until the main thread is observably inside the
                        # counting loop (lineno > 12 places it past the
                        # initial sendall in main_work), so it is busy and
                        # holding the GIL when we take the GIL-only sample.
                        for _ in range(MAX_TRIES):
                            all_traces = unwinder_all.get_stack_trace()
                            found = self._find_frame_in_trace(
                                all_traces,
                                lambda f: f.funcname == "main_work"
                                and f.lineno > 12,
                            )
                            if found:
                                break
                            time.sleep(0.1)
                        else:
                            self.fail(
                                "Main thread did not start its busy work on time"
                            )

                        # Get stack trace with only GIL holder
                        unwinder_gil = RemoteUnwinder(
                            p.pid, only_active_thread=True
                        )
                        gil_traces = unwinder_gil.get_stack_trace()
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Count threads
                    total_threads = sum(
                        len(interp.threads) for interp in all_traces
                    )
                    self.assertGreater(total_threads, 1)

                    # only_active_thread must collapse the sample to the
                    # single GIL-holding thread.
                    total_gil_threads = sum(
                        len(interp.threads) for interp in gil_traces
                    )
                    self.assertEqual(total_gil_threads, 1)

                    # Get the GIL holder thread ID
                    gil_thread_id = None
                    for interpreter_info in gil_traces:
                        if interpreter_info.threads:
                            gil_thread_id = interpreter_info.threads[
                                0
                            ].thread_id
                            break

                    # Get all thread IDs
                    all_thread_ids = []
                    for interpreter_info in all_traces:
                        for thread_info in interpreter_info.threads:
                            all_thread_ids.append(thread_info.thread_id)

                    self.assertIn(gil_thread_id, all_thread_ids)
            finally:
                _cleanup_sockets(client_socket, server_socket)


class TestUnsupportedPlatformHandling(unittest.TestCase):
    """Verify RemoteUnwinder fails loudly on platforms it does not support."""

    @unittest.skipIf(
        sys.platform in ("linux", "darwin", "win32"),
        "Test only runs on unsupported platforms (not Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_unsupported_platform_error(self):
        """Constructing an unwinder on an unsupported OS raises RuntimeError."""
        with self.assertRaises(RuntimeError) as ctx:
            RemoteUnwinder(os.getpid())

        expected_fragment = (
            "Reading the PyRuntime section is not supported on this platform"
        )
        self.assertIn(expected_fragment, str(ctx.exception))


class TestDetectionOfThreadStatus(RemoteInspectionTestBase):
    """Tests for the per-thread status flags (GIL / CPU) reported by
    RemoteUnwinder in the different profiling modes."""

    def _run_thread_status_test(self, mode, check_condition):
        """
        Common pattern for thread status detection tests.

        Spawns a child with a sleeping thread and a busy-spinning thread,
        receives both native thread ids over a socket, then samples the
        child with RemoteUnwinder in the given profiling mode until
        check_condition is satisfied (or MAX_TRIES samples elapse).

        Args:
            mode: Profiling mode (PROFILING_MODE_CPU, PROFILING_MODE_GIL, etc.)
            check_condition: Function(statuses, sleeper_tid, busy_tid) -> bool

        Returns:
            (statuses, sleeper_tid, busy_tid): the last sampled status map
            plus the two native thread ids.
        """
        port = find_unused_port()
        # Child script: "sleeper" blocks in time.sleep, "busy" spins forever
        # incrementing x (its trailing time.sleep(0.5) is unreachable dead
        # code).  Each thread reports its native id back over the socket.
        script = textwrap.dedent(
            f"""\
            import time, sys, socket, threading
            import os

            sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
            sock.connect(('localhost', {port}))

            def sleeper():
                tid = threading.get_native_id()
                sock.sendall(f'ready:sleeper:{{tid}}\\n'.encode())
                time.sleep(10000)

            def busy():
                tid = threading.get_native_id()
                sock.sendall(f'ready:busy:{{tid}}\\n'.encode())
                x = 0
                while True:
                    x = x + 1
                time.sleep(0.5)

            t1 = threading.Thread(target=sleeper)
            t2 = threading.Thread(target=busy)
            t1.start()
            t2.start()
            sock.sendall(b'ready:main\\n')
            t1.join()
            t2.join()
            sock.close()
            """
        )

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(
                script_dir, "thread_status_script", script
            )
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    # Wait for all ready signals and parse TIDs
                    response = _wait_for_signal(
                        client_socket,
                        [b"ready:main", b"ready:sleeper", b"ready:busy"],
                    )

                    # Extract native thread ids from "ready:<role>:<tid>"
                    # lines; malformed lines are ignored and caught by the
                    # assertIsNotNone checks below.
                    sleeper_tid = None
                    busy_tid = None
                    for line in response.split(b"\n"):
                        if line.startswith(b"ready:sleeper:"):
                            try:
                                sleeper_tid = int(line.split(b":")[-1])
                            except (ValueError, IndexError):
                                pass
                        elif line.startswith(b"ready:busy:"):
                            try:
                                busy_tid = int(line.split(b":")[-1])
                            except (ValueError, IndexError):
                                pass

                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )

                    # Sample until we see expected thread states
                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=mode,
                            skip_non_matching_threads=False,
                        )
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)

                            if check_condition(
                                statuses, sleeper_tid, busy_tid
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    # Return the last sample even if the condition never
                    # held; callers assert on it for a precise failure.
                    return statuses, sleeper_tid, busy_tid
            finally:
                _cleanup_sockets(client_socket, server_socket)

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_detection(self):
        """CPU mode: busy thread is on CPU, sleeping thread is not."""
        def check_cpu_status(statuses, sleeper_tid, busy_tid):
            # Both tids sampled, and only the busy thread has the ON_CPU bit.
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & THREAD_STATUS_ON_CPU)
                and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
            )

        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_CPU, check_cpu_status
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
            "Sleeper thread should be off CPU",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_ON_CPU,
            "Busy thread should be on CPU",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_gil_detection(self):
        """GIL mode: busy thread holds the GIL, sleeping thread does not."""
        def check_gil_status(statuses, sleeper_tid, busy_tid):
            # Both tids sampled, and only the busy thread has the HAS_GIL bit.
            return (
                sleeper_tid in statuses
                and busy_tid in statuses
                and not (statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL)
                and (statuses[busy_tid] & THREAD_STATUS_HAS_GIL)
            )

        statuses, sleeper_tid, busy_tid = self._run_thread_status_test(
            PROFILING_MODE_GIL, check_gil_status
        )

        self.assertIn(sleeper_tid, statuses)
        self.assertIn(busy_tid, statuses)
        self.assertFalse(
            statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
            "Sleeper thread should not have GIL",
        )
        self.assertTrue(
            statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
            "Busy thread should have GIL",
        )

    @unittest.skipIf(
        sys.platform not in ("linux", "darwin", "win32"),
        "Test only runs on supported platforms (Linux, macOS, or Windows)",
    )
    @unittest.skipIf(
        sys.platform == "android", "Android raises Linux-specific exception"
    )
    def test_thread_status_all_mode_detection(self):
        """ALL mode: a single sample carries both GIL and CPU status bits."""
        port = find_unused_port()
        # Each child thread opens its own connection and reports its own
        # native id, so the listener needs backlog=2 below.
        script = textwrap.dedent(
            f"""\
            import socket
            import threading
            import time
            import sys

            def sleeper_thread():
                conn = socket.create_connection(("localhost", {port}))
                conn.sendall(b"sleeper:" + str(threading.get_native_id()).encode())
                while True:
                    time.sleep(1)

            def busy_thread():
                conn = socket.create_connection(("localhost", {port}))
                conn.sendall(b"busy:" + str(threading.get_native_id()).encode())
                while True:
                    sum(range(100000))

            t1 = threading.Thread(target=sleeper_thread)
            t2 = threading.Thread(target=busy_thread)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """
        )

        with os_helper.temp_dir() as tmp_dir:
            script_file = make_script(tmp_dir, "script", script)
            server_socket = _create_server_socket(port, backlog=2)
            client_sockets = []

            try:
                with _managed_subprocess(
                    [sys.executable, script_file],
                ) as p:
                    sleeper_tid = None
                    busy_tid = None

                    # Receive thread IDs from the child process
                    for _ in range(2):
                        client_socket, _ = server_socket.accept()
                        client_sockets.append(client_socket)
                        line = client_socket.recv(1024)
                        if line:
                            if line.startswith(b"sleeper:"):
                                try:
                                    sleeper_tid = int(line.split(b":")[-1])
                                except (ValueError, IndexError):
                                    pass
                            elif line.startswith(b"busy:"):
                                try:
                                    busy_tid = int(line.split(b":")[-1])
                                except (ValueError, IndexError):
                                    pass

                    server_socket.close()
                    server_socket = None

                    statuses = {}
                    try:
                        unwinder = RemoteUnwinder(
                            p.pid,
                            all_threads=True,
                            mode=PROFILING_MODE_ALL,
                            skip_non_matching_threads=False,
                        )
                        # Sample until the expected combination appears:
                        # sleeper off-CPU without the GIL, busy on-CPU with it.
                        for _ in range(MAX_TRIES):
                            traces = unwinder.get_stack_trace()
                            statuses = self._get_thread_statuses(traces)

                            # Check ALL mode provides both GIL and CPU info
                            if (
                                sleeper_tid in statuses
                                and busy_tid in statuses
                                and not (
                                    statuses[sleeper_tid]
                                    & THREAD_STATUS_ON_CPU
                                )
                                and not (
                                    statuses[sleeper_tid]
                                    & THREAD_STATUS_HAS_GIL
                                )
                                and (statuses[busy_tid] & THREAD_STATUS_ON_CPU)
                                and (
                                    statuses[busy_tid] & THREAD_STATUS_HAS_GIL
                                )
                            ):
                                break
                            time.sleep(0.5)
                    except PermissionError:
                        self.skipTest(
                            "Insufficient permissions to read the stack trace"
                        )

                    self.assertIsNotNone(
                        sleeper_tid, "Sleeper thread id not received"
                    )
                    self.assertIsNotNone(
                        busy_tid, "Busy thread id not received"
                    )
                    self.assertIn(sleeper_tid, statuses)
                    self.assertIn(busy_tid, statuses)

                    # Sleeper: off CPU, no GIL
                    self.assertFalse(
                        statuses[sleeper_tid] & THREAD_STATUS_ON_CPU,
                        "Sleeper should be off CPU",
                    )
                    self.assertFalse(
                        statuses[sleeper_tid] & THREAD_STATUS_HAS_GIL,
                        "Sleeper should not have GIL",
                    )

                    # Busy: on CPU, has GIL
                    self.assertTrue(
                        statuses[busy_tid] & THREAD_STATUS_ON_CPU,
                        "Busy should be on CPU",
                    )
                    self.assertTrue(
                        statuses[busy_tid] & THREAD_STATUS_HAS_GIL,
                        "Busy should have GIL",
                    )
            finally:
                _cleanup_sockets(*client_sockets, server_socket)


class TestFrameCaching(RemoteInspectionTestBase):
    """Test that frame caching produces correct results.

    Uses socket-based synchronization for deterministic testing.
    All tests verify cache reuse via object identity checks (assertIs).
    """

    @contextmanager
    def _target_process(self, script_body):
        """Context manager for running a target process with socket sync."""
        port = find_unused_port()
        script = f"""\
import socket
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
sock.connect(('localhost', {port}))
{textwrap.dedent(script_body)}
"""

        with os_helper.temp_dir() as work_dir:
            script_dir = os.path.join(work_dir, "script_pkg")
            os.mkdir(script_dir)

            server_socket = _create_server_socket(port)
            script_name = _make_test_script(script_dir, "script", script)
            client_socket = None

            try:
                with _managed_subprocess([sys.executable, script_name]) as p:
                    client_socket, _ = server_socket.accept()
                    server_socket.close()
                    server_socket = None

                    def make_unwinder(cache_frames=True):
                        return RemoteUnwinder(
                            p.pid, all_threads=True, cache_frames=cache_frames
                        )

                    yield p, client_socket, make_unwinder

            except PermissionError:
                self.skipTest(
                    "Insufficient permissions to read the stack trace"
                )
            finally:
                _cleanup_sockets(client_socket, server_socket)

    def _get_frames_with_retry(self, unwinder, required_funcs):
        """Sample until one thread's stack contains every required_funcs name.

        Transient read failures (OSError/RuntimeError) are swallowed and the
        sample retried; returns the matching frame list, or None on timeout.
        """
        attempts = MAX_TRIES
        while attempts:
            attempts -= 1
            with contextlib.suppress(OSError, RuntimeError):
                for interpreter in unwinder.get_stack_trace():
                    for thread_state in interpreter.threads:
                        seen = {
                            frame.funcname
                            for frame in thread_state.frame_info
                        }
                        if required_funcs <= seen:
                            return thread_state.frame_info
            time.sleep(0.1)
        return None

    def _sample_frames(
        self,
        client_socket,
        unwinder,
        wait_signal,
        send_ack,
        required_funcs,
        expected_frames=1,
    ):
        """Wait for *wait_signal* from the child, then sample its stack.

        Retries until the sample contains *required_funcs* and at least
        *expected_frames* frames, acknowledges the child with *send_ack*,
        and returns the frames (possibly None if sampling never succeeded).
        """
        _wait_for_signal(client_socket, wait_signal)
        frames = None
        tries_left = MAX_TRIES
        while tries_left > 0:
            tries_left -= 1
            frames = self._get_frames_with_retry(unwinder, required_funcs)
            if frames and len(frames) >= expected_frames:
                break
            time.sleep(0.1)
        # Ack only after sampling, so the child stays parked on recv()
        # while its stack is being read.
        client_socket.sendall(send_ack)
        return frames

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_hit_same_stack(self):
        """Consecutive samples of an unchanged stack reuse parent frames.

        Only the innermost frame (index 0) is re-read each sample to pick up
        the current line number, so it may be a fresh object; every frame
        above it must come back as the very same cached object.
        """
        script_body = """\
            def level3():
                sock.sendall(b"sync1")
                sock.recv(16)
                sock.sendall(b"sync2")
                sock.recv(16)
                sock.sendall(b"sync3")
                sock.recv(16)

            def level2():
                level3()

            def level1():
                level2()

            level1()
            """

        with self._target_process(script_body) as (
            proc,
            conn,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)
            required = {"level1", "level2", "level3"}

            # Three samples while the child is parked at the same stack.
            sync_plan = (
                (b"sync1", b"ack"),
                (b"sync2", b"ack"),
                (b"sync3", b"done"),
            )
            samples = [
                self._sample_frames(conn, unwinder, signal, ack, required)
                for signal, ack in sync_plan
            ]
            frames1, frames2, frames3 = samples

        for sample in samples:
            self.assertIsNotNone(sample)
        self.assertEqual(len(frames1), len(frames2))
        self.assertEqual(len(frames2), len(frames3))

        # The innermost frame is re-read each time; only its value is stable.
        self.assertEqual(frames1[0].funcname, frames2[0].funcname)
        self.assertEqual(frames2[0].funcname, frames3[0].funcname)

        # Everything above it must be identical cached objects.
        for i, (f1, f2, f3) in enumerate(
            zip(frames1[1:], frames2[1:], frames3[1:]), start=1
        ):
            self.assertIs(
                f1, f2, f"Frame {i}: samples 1-2 must be same object"
            )
            self.assertIs(
                f2, f3, f"Frame {i}: samples 2-3 must be same object"
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_line_number_updates_in_same_frame(self):
        """Line numbers track execution within a single cached frame.

        Sampling the same function at four sync points must report four
        strictly increasing line numbers — never stale cached values.
        """
        script_body = """\
            def outer():
                inner()

            def inner():
                sock.sendall(b"line_a"); sock.recv(16)
                sock.sendall(b"line_b"); sock.recv(16)
                sock.sendall(b"line_c"); sock.recv(16)
                sock.sendall(b"line_d"); sock.recv(16)

            outer()
            """

        with self._target_process(script_body) as (
            proc,
            conn,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            sync_plan = (
                (b"line_a", b"ack"),
                (b"line_b", b"ack"),
                (b"line_c", b"ack"),
                (b"line_d", b"done"),
            )
            samples = [
                self._sample_frames(conn, unwinder, signal, ack, {"inner"})
                for signal, ack in sync_plan
            ]

        for sample in samples:
            self.assertIsNotNone(sample)

        # The 'inner' frame is the innermost one in every sample.
        inner_frames = [sample[0] for sample in samples]
        for frame in inner_frames:
            self.assertEqual(frame.funcname, "inner")

        # Execution only moves forward, so line numbers strictly increase.
        labels = ("A", "B", "C", "D")
        for idx in range(1, len(labels)):
            previous, current = inner_frames[idx - 1], inner_frames[idx]
            self.assertLess(
                previous.lineno,
                current.lineno,
                f"Line {labels[idx]} should be after line {labels[idx - 1]}",
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_invalidation_on_return(self):
        """A frame must drop out of reported stacks once its function returns."""
        script_body = """\
            def inner():
                sock.sendall(b"at_inner")
                sock.recv(16)

            def outer():
                inner()
                sock.sendall(b"at_outer")
                sock.recv(16)

            outer()
            """

        with self._target_process(script_body) as (
            proc,
            conn,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            # First sample: inner() is live, so both frames are on the stack.
            frames_deep = self._sample_frames(
                conn,
                unwinder,
                b"at_inner",
                b"ack",
                {"inner", "outer"},
            )
            # Second sample: inner() has returned; only outer() remains.
            frames_shallow = self._sample_frames(
                conn, unwinder, b"at_outer", b"done", {"outer"}
            )

        self.assertIsNotNone(frames_deep)
        self.assertIsNotNone(frames_shallow)

        deep_names = {frame.funcname for frame in frames_deep}
        shallow_names = {frame.funcname for frame in frames_shallow}

        self.assertIn("inner", deep_names)
        self.assertIn("outer", deep_names)
        self.assertNotIn("inner", shallow_names)
        self.assertIn("outer", shallow_names)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_invalidation_on_call(self):
        """Growing the stack (a new call) must surface the new frame."""
        script_body = """\
            def deeper():
                sock.sendall(b"at_deeper")
                sock.recv(16)

            def middle():
                sock.sendall(b"at_middle")
                sock.recv(16)
                deeper()

            def top():
                middle()

            top()
            """

        with self._target_process(script_body) as (
            proc,
            conn,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            # First sample: child is inside middle(), before calling deeper().
            frames_before = self._sample_frames(
                conn,
                unwinder,
                b"at_middle",
                b"ack",
                {"middle", "top"},
            )
            # Second sample: deeper() is now on top of the stack.
            frames_after = self._sample_frames(
                conn,
                unwinder,
                b"at_deeper",
                b"done",
                {"deeper", "middle", "top"},
            )

        self.assertIsNotNone(frames_before)
        self.assertIsNotNone(frames_after)

        names_before = {frame.funcname for frame in frames_before}
        names_after = {frame.funcname for frame in frames_after}

        self.assertIn("middle", names_before)
        self.assertIn("top", names_before)
        self.assertNotIn("deeper", names_before)

        self.assertIn("deeper", names_after)
        self.assertIn("middle", names_after)
        self.assertIn("top", names_after)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_partial_stack_reuse(self):
        """Shared bottom frames survive a top-of-stack change (A→B→C to A→B→D)."""
        script_body = """\
            def func_c():
                sock.sendall(b"at_c")
                sock.recv(16)

            def func_d():
                sock.sendall(b"at_d")
                sock.recv(16)

            def func_b():
                func_c()
                func_d()

            def func_a():
                func_b()

            func_a()
            """

        with self._target_process(script_body) as (
            proc,
            conn,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            # Sample at C: stack is A→B→C
            frames_c = self._sample_frames(
                conn,
                unwinder,
                b"at_c",
                b"ack",
                {"func_a", "func_b", "func_c"},
            )
            # Sample at D: stack is A→B→D (C returned, D called)
            frames_d = self._sample_frames(
                conn,
                unwinder,
                b"at_d",
                b"done",
                {"func_a", "func_b", "func_d"},
            )

        self.assertIsNotNone(frames_c)
        self.assertIsNotNone(frames_d)

        def _locate(frames, funcname):
            # First frame matching funcname, or None if absent.
            return next(
                (frame for frame in frames if frame.funcname == funcname),
                None,
            )

        frame_a_in_c = _locate(frames_c, "func_a")
        frame_b_in_c = _locate(frames_c, "func_b")
        frame_a_in_d = _locate(frames_d, "func_a")
        frame_b_in_d = _locate(frames_d, "func_b")

        for located in (
            frame_a_in_c,
            frame_b_in_c,
            frame_a_in_d,
            frame_b_in_d,
        ):
            self.assertIsNotNone(located)

        # Unchanged bottom-of-stack frames must be the very same objects.
        self.assertIs(
            frame_a_in_c,
            frame_a_in_d,
            "func_a frame should be reused from cache",
        )
        self.assertIs(
            frame_b_in_c,
            frame_b_in_d,
            "func_b frame should be reused from cache",
        )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_recursive_frames(self):
        """Test caching with same function appearing multiple times (recursion)."""
        script_body = """\
            def recurse(n):
                if n <= 0:
                    sock.sendall(b"sync1")
                    sock.recv(16)
                    sock.sendall(b"sync2")
                    sock.recv(16)
                else:
                    recurse(n - 1)

            recurse(5)
            """

        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            first_sample = self._sample_frames(
                client_socket, unwinder, b"sync1", b"ack", {"recurse"}
            )
            second_sample = self._sample_frames(
                client_socket, unwinder, b"sync2", b"done", {"recurse"}
            )

        self.assertIsNotNone(first_sample)
        self.assertIsNotNone(second_sample)

        # recurse(5) leaves 6 recursive frames on the stack: n=5 down to n=0.
        recurse_frames = [
            f for f in first_sample if f.funcname == "recurse"
        ]
        self.assertEqual(
            len(recurse_frames), 6, "Should have 6 recursive frames"
        )

        self.assertEqual(len(first_sample), len(second_sample))

        # The innermost frame (index 0) is always re-read from the target,
        # so only value equality is guaranteed there.
        self.assertEqual(first_sample[0].funcname, second_sample[0].funcname)

        # Every caller frame (index 1+) must come back as the identical
        # cached object.
        for idx, (f1, f2) in enumerate(zip(first_sample, second_sample)):
            if idx == 0:
                continue
            self.assertIs(
                f1,
                f2,
                f"Frame {idx}: recursive frames must be same object",
            )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_vs_no_cache_equivalence(self):
        """Test that cache_frames=True and cache_frames=False produce equivalent results."""
        script_body = """\
            def level3():
                sock.sendall(b"ready"); sock.recv(16)

            def level2():
                level3()

            def level1():
                level2()

            level1()
            """

        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            _wait_for_signal(client_socket, b"ready")

            expected_funcs = {"level1", "level2", "level3"}

            # One sample with the frame cache enabled...
            frames_cached = self._get_frames_with_retry(
                make_unwinder(cache_frames=True), expected_funcs
            )
            # ...and one with it disabled, while the target stays paused.
            frames_no_cache = self._get_frames_with_retry(
                make_unwinder(cache_frames=False), expected_funcs
            )

            client_socket.sendall(b"done")

        self.assertIsNotNone(frames_cached)
        self.assertIsNotNone(frames_no_cache)

        # Identical stack depth.
        self.assertEqual(len(frames_cached), len(frames_no_cache))

        # Identical function names, in order.
        self.assertEqual(
            [f.funcname for f in frames_cached],
            [f.funcname for f in frames_no_cache],
        )

        # Identical line numbers, in order.
        self.assertEqual(
            [f.lineno for f in frames_cached],
            [f.lineno for f in frames_no_cache],
        )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_per_thread_isolation(self):
        """Test that frame cache is per-thread and cache invalidation works independently.

        Two target threads descend into their own call chains and rendezvous
        over one shared socket, serialized by a lock in the target script.
        Each sync point is sampled; the captured stacks are checked both for
        per-thread correctness and for the absence of frames belonging to
        the other thread (no cache cross-contamination).
        """
        script_body = """\
            import threading

            lock = threading.Lock()

            def sync(msg):
                with lock:
                    sock.sendall(msg + b"\\n")
                    sock.recv(1)

            # Thread 1 functions
            def baz1():
                sync(b"t1:baz1")

            def bar1():
                baz1()

            def blech1():
                sync(b"t1:blech1")

            def foo1():
                bar1()  # Goes down to baz1, syncs
                blech1()  # Returns up, goes down to blech1, syncs

            # Thread 2 functions
            def baz2():
                sync(b"t2:baz2")

            def bar2():
                baz2()

            def blech2():
                sync(b"t2:blech2")

            def foo2():
                bar2()  # Goes down to baz2, syncs
                blech2()  # Returns up, goes down to blech2, syncs

            t1 = threading.Thread(target=foo1)
            t2 = threading.Thread(target=foo2)
            t1.start()
            t2.start()
            t1.join()
            t2.join()
            """

        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder = make_unwinder(cache_frames=True)

            # Message dispatch table: signal -> required functions for that thread
            dispatch = {
                b"t1:baz1": {"baz1", "bar1", "foo1"},
                b"t2:baz2": {"baz2", "bar2", "foo2"},
                b"t1:blech1": {"blech1", "foo1"},
                b"t2:blech2": {"blech2", "foo2"},
            }

            # Track results for each sync point
            results = {}

            # Process 4 sync points (order depends on thread scheduling)
            # Messages are newline-terminated and may arrive batched in a
            # single read, so keep a rolling buffer across iterations.
            buffer = _wait_for_signal(client_socket, b"\n")
            for i in range(4):
                # Extract first message from buffer
                msg, sep, buffer = buffer.partition(b"\n")
                self.assertIn(msg, dispatch, f"Unexpected message: {msg!r}")

                # Sample frames for the thread at this sync point
                required_funcs = dispatch[msg]
                frames = self._get_frames_with_retry(unwinder, required_funcs)
                self.assertIsNotNone(frames, f"Thread not found for {msg!r}")
                results[msg] = [f.funcname for f in frames]

                # Release thread and wait for next message (if not last)
                # Append: the remainder of the buffer may already hold the
                # next complete message from the other thread.
                client_socket.sendall(b"k")
                if i < 3:
                    buffer += _wait_for_signal(client_socket, b"\n")

            # Validate Phase 1: baz snapshots
            t1_baz = results.get(b"t1:baz1")
            t2_baz = results.get(b"t2:baz2")
            self.assertIsNotNone(t1_baz, "Missing t1:baz1 snapshot")
            self.assertIsNotNone(t2_baz, "Missing t2:baz2 snapshot")

            # Thread 1 at baz1: should have foo1->bar1->baz1
            self.assertIn("baz1", t1_baz)
            self.assertIn("bar1", t1_baz)
            self.assertIn("foo1", t1_baz)
            self.assertNotIn("blech1", t1_baz)
            # No cross-contamination
            self.assertNotIn("baz2", t1_baz)
            self.assertNotIn("bar2", t1_baz)
            self.assertNotIn("foo2", t1_baz)

            # Thread 2 at baz2: should have foo2->bar2->baz2
            self.assertIn("baz2", t2_baz)
            self.assertIn("bar2", t2_baz)
            self.assertIn("foo2", t2_baz)
            self.assertNotIn("blech2", t2_baz)
            # No cross-contamination
            self.assertNotIn("baz1", t2_baz)
            self.assertNotIn("bar1", t2_baz)
            self.assertNotIn("foo1", t2_baz)

            # Validate Phase 2: blech snapshots (cache invalidation test)
            t1_blech = results.get(b"t1:blech1")
            t2_blech = results.get(b"t2:blech2")
            self.assertIsNotNone(t1_blech, "Missing t1:blech1 snapshot")
            self.assertIsNotNone(t2_blech, "Missing t2:blech2 snapshot")

            # Thread 1 at blech1: bar1/baz1 should be GONE (cache invalidated)
            self.assertIn("blech1", t1_blech)
            self.assertIn("foo1", t1_blech)
            self.assertNotIn(
                "bar1", t1_blech, "Cache not invalidated: bar1 still present"
            )
            self.assertNotIn(
                "baz1", t1_blech, "Cache not invalidated: baz1 still present"
            )
            # No cross-contamination
            self.assertNotIn("blech2", t1_blech)

            # Thread 2 at blech2: bar2/baz2 should be GONE (cache invalidated)
            self.assertIn("blech2", t2_blech)
            self.assertIn("foo2", t2_blech)
            self.assertNotIn(
                "bar2", t2_blech, "Cache not invalidated: bar2 still present"
            )
            self.assertNotIn(
                "baz2", t2_blech, "Cache not invalidated: baz2 still present"
            )
            # No cross-contamination
            self.assertNotIn("blech1", t2_blech)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_new_unwinder_with_stale_last_profiled_frame(self):
        """Test that a new unwinder returns complete stack when cache lookup misses."""
        script_body = """\
            def level4():
                sock.sendall(b"sync1")
                sock.recv(16)
                sock.sendall(b"sync2")
                sock.recv(16)

            def level3():
                level4()

            def level2():
                level3()

            def level1():
                level2()

            level1()
            """

        all_levels = ["level1", "level2", "level3", "level4"]

        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            expected = set(all_levels)

            # Sampling with the first unwinder sets last_profiled_frame
            # in the target process.
            frames1 = self._sample_frames(
                client_socket,
                make_unwinder(cache_frames=True),
                b"sync1",
                b"ack",
                expected,
            )

            # A brand-new unwinder starts with an empty cache while the
            # target still carries the stale last_profiled_frame above.
            frames2 = self._sample_frames(
                client_socket,
                make_unwinder(cache_frames=True),
                b"sync2",
                b"done",
                expected,
            )

        self.assertIsNotNone(frames1)
        self.assertIsNotNone(frames2)

        names1 = [f.funcname for f in frames1]
        names2 = [f.funcname for f in frames2]

        # Both samples must contain every level of the call chain.
        for level in all_levels:
            self.assertIn(level, names1, f"{level} missing from first sample")
            self.assertIn(level, names2, f"{level} missing from second sample")

        # And the same stack depth.
        self.assertEqual(
            len(frames1),
            len(frames2),
            "New unwinder should return complete stack despite stale last_profiled_frame",
        )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_cache_exhaustion(self):
        """Test cache works when frame limit (1024) is exceeded.

        FRAME_CACHE_MAX_FRAMES=1024. With 1100 recursive frames,
        the cache can't store all of them but should still work.
        """
        # Use 1100 to exceed FRAME_CACHE_MAX_FRAMES=1024
        depth = 1100
        script_body = f"""\
import sys
sys.setrecursionlimit(2000)

def recurse(n):
    if n <= 0:
        sock.sendall(b"ready")
        sock.recv(16)  # wait for ack
        sock.sendall(b"ready2")
        sock.recv(16)  # wait for done
        return
    recurse(n - 1)

recurse({depth})
"""

        with self._target_process(script_body) as (
            p,
            client_socket,
            make_unwinder,
        ):
            unwinder_cache = make_unwinder(cache_frames=True)
            unwinder_no_cache = make_unwinder(cache_frames=False)

            frames_cached = self._sample_frames(
                client_socket,
                unwinder_cache,
                b"ready",
                b"ack",
                {"recurse"},
                expected_frames=1102,
            )
            # Sample again with no cache for comparison
            frames_no_cache = self._sample_frames(
                client_socket,
                unwinder_no_cache,
                b"ready2",
                b"done",
                {"recurse"},
                expected_frames=1102,
            )

        self.assertIsNotNone(frames_cached)
        self.assertIsNotNone(frames_no_cache)

        # Both should have many recurse frames (> 1024 limit)
        cached_count = [f.funcname for f in frames_cached].count("recurse")
        no_cache_count = [f.funcname for f in frames_no_cache].count("recurse")

        self.assertGreater(
            cached_count, 1000, "Should have >1000 recurse frames"
        )
        self.assertGreater(
            no_cache_count, 1000, "Should have >1000 recurse frames"
        )

        # Both modes should produce same frame count
        self.assertEqual(
            len(frames_cached),
            len(frames_no_cache),
            "Cache exhaustion should not affect stack completeness",
        )

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_get_stats(self):
        """Test that get_stats() returns statistics when stats=True."""
        script_body = """\
            sock.sendall(b"ready")
            sock.recv(16)
            """

        with self._target_process(script_body) as (p, client_socket, _):
            unwinder = RemoteUnwinder(p.pid, all_threads=True, stats=True)
            _wait_for_signal(client_socket, b"ready")

            # A single sample is enough to populate the counters.
            unwinder.get_stack_trace()

            stats = unwinder.get_stats()
            client_socket.sendall(b"done")

        # Every documented statistic must be present in the result.
        for key in (
            "total_samples",
            "frame_cache_hits",
            "frame_cache_misses",
            "frame_cache_partial_hits",
            "frames_read_from_cache",
            "frames_read_from_memory",
            "frame_cache_hit_rate",
        ):
            self.assertIn(key, stats)

        self.assertEqual(stats["total_samples"], 1)

    @skip_if_not_supported
    @unittest.skipIf(
        sys.platform == "linux" and not PROCESS_VM_READV_SUPPORTED,
        "Test only runs on Linux with process_vm_readv support",
    )
    def test_get_stats_disabled_raises(self):
        """Test that get_stats() raises RuntimeError when stats=False."""
        script_body = """\
            sock.sendall(b"ready")
            sock.recv(16)
            """

        with self._target_process(script_body) as (p, client_socket, _):
            # stats defaults to False, so statistics are never collected.
            unwinder = RemoteUnwinder(p.pid, all_threads=True)
            _wait_for_signal(client_socket, b"ready")

            with self.assertRaises(RuntimeError):
                unwinder.get_stats()

            client_socket.sendall(b"done")


# Run the full test suite when this module is executed directly.
if __name__ == "__main__":
    unittest.main()
